[PyTorch] Backwards compatible single param checkpointing in GroupedLinear#2761
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR adds backwards-compatible checkpoint loading to GroupedLinear.

Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["User calls GroupedLinear.load_state_dict(state_dict, strict, assign)"] --> B["Make shallow copy: state_dict_copy"]
B --> C["_remap_grouped_weight_state_dict_keys(state_dict_copy, prefix='')"]
C --> D{single_grouped_parameter?}
D -- "True\n(want 'weight')" --> E{has weight0..N\nbut no weight?}
E -- Yes --> F["Stack weight0..N → plain torch.Tensor\nInsert as 'weight'"]
E -- No --> G["Drop redundant per-GEMM keys"]
D -- "False\n(want weight0..N)" --> H{has 'weight'\nbut no weight0..N?}
H -- Yes --> I{is GroupedTensor?}
I -- Yes --> J["split_into_quantized_tensors()\nor use .quantized_tensors"]
I -- No --> K["unbind(dim=0)"]
J --> L["Insert weight0..N as plain tensors"]
K --> L
H -- No --> M["Drop redundant 'weight' key"]
F --> N["super().load_state_dict(state_dict_copy)"]
G --> N
L --> N
M --> N
N --> O["PyTorch recursive loading\n→ calls _load_from_state_dict"]
O --> P["_remap_grouped_weight_state_dict_keys again\n(idempotent, no-op)"]
P --> Q["super()._load_from_state_dict()"]
Q --> R{assign=True?}
R -- "No (default)" --> S["param.copy_(state_dict_value)\nGroupedTensor __torch_dispatch__ handles copy"]
R -- "Yes ⚠️" --> T["setattr(module, 'weight', plain_tensor)\nGroupedTensor replaced by plain tensor — BUG"]
```

Last reviewed commit: ebd23b9
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```
Double remapping of weight keys
_remap_grouped_weight_state_dict_keys is applied twice whenever GroupedLinear.load_state_dict is the entry point:
- Explicitly in load_state_dict (line 907).
- Again inside GroupedLinear._load_from_state_dict (line 914), which PyTorch's super().load_state_dict() invokes internally as part of its recursive loading loop.
The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights).
A straightforward fix is to skip the remap inside load_state_dict and let _load_from_state_dict handle it exclusively (which already covers the nested-module case). The copy is still needed to avoid mutating the caller's dict, so it should be preserved:
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    # Key remapping is performed in _load_from_state_dict, which PyTorch
    # calls internally; no need to remap again here.
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```

This keeps the copy (protecting the caller's dict) and relies on _load_from_state_dict for the single, canonical remap path in all cases.
```python
torch.save(src.state_dict(), ckpt_path)
del src
...
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
```
weights_only=False enables arbitrary pickle execution
torch.load(..., weights_only=False) deserialises the file using Python's pickle module, which executes arbitrary code embedded in the file. PyTorch 2.x already emits a FutureWarning for this pattern and the default will flip to True in a future release.
For the multi-to-single test (test_grouped_linear_load_state_dict_multi_to_single_param) the source model uses single_grouped_parameter=False, so all saved tensors are plain torch.Tensor objects — weights_only=True should work fine there.
```diff
- src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
```
The same concern applies to line 540 in test_grouped_linear_load_state_dict_single_to_multi_param. For that test the saved weight is a GroupedTensor subclass, which may require weights_only=False to deserialise; if so, the incompatibility should be documented with an inline comment explaining why weights_only=True cannot be used.
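As a hedged sketch of the safe-load path (tensor names and shapes here are made up for illustration): a checkpoint containing only plain tensors loads fine with weights_only=True, and for checkpoints containing a custom tensor subclass, PyTorch 2.4+ offers torch.serialization.add_safe_globals as an alternative to falling back to full pickle.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "grouped_linear_per_gemm.pt")
    # Per-GEMM checkpoint: every value is a plain torch.Tensor,
    # so the safe (non-pickle) load path works.
    torch.save({f"weight{i}": torch.randn(4, 8) for i in range(3)}, ckpt_path)
    loaded = torch.load(ckpt_path, map_location="cpu", weights_only=True)

    # For a single-param checkpoint containing a tensor subclass, the class
    # could be allow-listed instead of using weights_only=False, e.g.:
    #   torch.serialization.add_safe_globals([GroupedTensor])  # hypothetical name
```

This keeps pickle code execution disabled for the common case while documenting the escape hatch for the subclass case.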
/te-ci pytorch |
```python
expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)]
ckpt_path = tmp_path / "grouped_linear_per_gemm.pt"
torch.save(src.state_dict(), ckpt_path)
del src
```
Should we also add a test case for quantized_model_init(mxfp8)? Shouldn't be a blocker for this PR, though.
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```
assign=True replaces GroupedTensor with a plain tensor
When assign=True is passed to load_state_dict and the multi-to-single conversion is active (single_grouped_parameter=True, checkpoint has weight0..N), _remap_grouped_weight_state_dict_keys writes a plain torch.Tensor (from torch.stack) into state_dict_copy["weight"]. PyTorch's assign=True path then calls setattr(module, "weight", plain_tensor) instead of param.copy_(plain_tensor), so the GroupedTensor parameter is silently replaced by a plain tensor. Any subsequent forward pass that calls self.weight.split_into_quantized_tensors() or relies on the GroupedTensor.__torch_dispatch__ mechanism will crash or silently compute incorrect results.
A fix is to either document that assign=True is unsupported for cross-format loading, or reconstruct a proper GroupedTensor inside the remap helper when the target format is single_grouped_parameter=True:
```python
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    if assign:
        warnings.warn(
            "GroupedLinear.load_state_dict with assign=True does not support "
            "cross-format checkpoint loading. Use assign=False (default).",
            UserWarning,
        )
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
```

```python
if not has_per_gemm_weights and has_grouped_weight:
    grouped_weight = state_dict.pop(grouped_weight_key)
    if hasattr(grouped_weight, "split_into_quantized_tensors"):
        grouped_members = grouped_weight.quantized_tensors
        if grouped_members is None:
            grouped_members = grouped_weight.split_into_quantized_tensors()
        per_gemm_weights = [
            (
                weight.dequantize()
                if isinstance(weight, QuantizedTensorStorage)
                else weight
            )
            for weight in grouped_members
        ]
    else:
        grouped_weight = (
            grouped_weight.dequantize()
            if isinstance(grouped_weight, QuantizedTensorStorage)
            else grouped_weight
        )
        per_gemm_weights = list(grouped_weight.unbind(dim=0))
    for i, weight in enumerate(per_gemm_weights):
        state_dict[f"{prefix}weight{i}"] = weight
```
No validation of GEMM count after splitting
When splitting a grouped checkpoint into per-GEMM weights, neither the split_into_quantized_tensors() path nor the unbind(dim=0) path validates that the number of recovered tensors equals self.num_gemms. If the checkpoint was created with a different number of GEMMs (e.g., num_gemms=5 saved, num_gemms=3 loaded), the remap will silently inject weight0..4 into the state dict. With strict=True, PyTorch will then report weight3 and weight4 as unexpected keys — but the diagnostic message gives no hint that the root cause is a GEMM-count mismatch.
Adding an explicit early check here improves debuggability:
```python
if hasattr(grouped_weight, "split_into_quantized_tensors"):
    grouped_members = grouped_weight.quantized_tensors
    if grouped_members is None:
        grouped_members = grouped_weight.split_into_quantized_tensors()
    if len(grouped_members) != self.num_gemms:
        raise ValueError(
            f"Checkpoint grouped weight contains {len(grouped_members)} GEMMs "
            f"but this module was configured with num_gemms={self.num_gemms}."
        )
    ...
else:
    per_gemm_weights = list(grouped_weight.unbind(dim=0))
    if len(per_gemm_weights) != self.num_gemms:
        raise ValueError(
            f"Checkpoint stacked weight has {len(per_gemm_weights)} slices along dim=0 "
            f"but this module was configured with num_gemms={self.num_gemms}."
        )
```

```python
def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Load state, including compatibility across grouped-weight checkpoint formats."""
    self._remap_grouped_weight_state_dict_keys(state_dict, prefix)

    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )
```
_load_from_state_dict mutates the shared state dict in-place
When GroupedLinear is used as a submodule, PyTorch passes the same state_dict object (with the full prefix tree) to every module's _load_from_state_dict. This override calls _remap_grouped_weight_state_dict_keys, which modifies that shared dict in-place — popping old keys and inserting new ones (e.g. swapping "parent.grouped.weight" out for "parent.grouped.weight0..N").
While the key operations are scoped by prefix and don't touch other modules' keys, the mutation is a side-effect that:
- Permanently alters the caller's state dict after the fact (the user may not expect their dict to be modified when loading a submodule).
- Interacts unexpectedly with the unexpected_keys accounting in PyTorch's base _load_from_state_dict if the newly injected keys are not all consumed.
A defensive pattern is to work on a shallow copy of only the module's relevant key-space, similar to what load_state_dict already does at the top level. At minimum, adding a comment here that the mutation is intentional and scoped would reduce the maintenance burden.
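A minimal sketch of that defensive pattern using plain dicts (the remap helper below is an illustrative stand-in for _remap_grouped_weight_state_dict_keys, which in the PR operates on real tensors):

```python
def remap_grouped_keys(state_dict, prefix):
    """Illustrative stand-in for _remap_grouped_weight_state_dict_keys:
    split '<prefix>weight' into '<prefix>weight0..N-1' in place."""
    key = f"{prefix}weight"
    if key in state_dict and f"{prefix}weight0" not in state_dict:
        for i, member in enumerate(state_dict.pop(key)):
            state_dict[f"{prefix}weight{i}"] = member


def remap_on_scoped_copy(state_dict, prefix):
    """Return a remapped shallow copy; the caller's dict is left untouched.

    Only keys under `prefix` are rewritten; everything else passes through.
    """
    copied = dict(state_dict)
    remap_grouped_keys(copied, prefix)
    return copied
```

An override of _load_from_state_dict could then hand the copy to super() instead of mutating the shared dict, though PyTorch's missing/unexpected-key bookkeeping would then observe the copy rather than the caller's original.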
```python
    state_dict[grouped_weight_key] = torch.stack(per_gemm_weights, dim=0)
elif has_grouped_weight:
    # Drop any redundant per-GEMM keys to avoid strict-load unexpected-key errors.
    for key in per_gemm_weight_keys:
```
We might need this even for TE sequential checkpointing, right? Maybe putting it in utils and using it in both places makes sense, to avoid code duplication.
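A hedged sketch of that refactor (the function name and location are hypothetical, and real TE code would also have to handle quantized/GroupedTensor weights): a single shared helper that converts between the two key formats, callable from both GroupedLinear and the sequential-checkpointing path.

```python
import torch


def remap_grouped_weight_keys(state_dict, prefix, num_gemms, want_single):
    """Convert between 'weight' and 'weight0..N-1' key formats in place.

    Hypothetical shared utils helper; handles only plain tensors for
    illustration (the real remap must also cover quantized tensors).
    """
    grouped_key = f"{prefix}weight"
    per_gemm_keys = [f"{prefix}weight{i}" for i in range(num_gemms)]
    if want_single and grouped_key not in state_dict:
        if all(k in state_dict for k in per_gemm_keys):
            # Per-GEMM checkpoint -> single stacked parameter.
            state_dict[grouped_key] = torch.stack(
                [state_dict.pop(k) for k in per_gemm_keys], dim=0
            )
    elif not want_single and grouped_key in state_dict:
        if not any(k in state_dict for k in per_gemm_keys):
            # Single stacked parameter -> per-GEMM checkpoint.
            for i, w in enumerate(state_dict.pop(grouped_key).unbind(dim=0)):
                state_dict[f"{prefix}weight{i}"] = w
```

Both call sites would then share one canonical conversion path instead of duplicating the key juggling.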
…Linear` (#2761)

* Load multi-param checkpoint from single-param config in GroupedLinear
* Multi-param to single param case
* Multi-param to single param case
* Better varnames

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

…Linear` (NVIDIA#2761)

* Load multi-param checkpoint from single-param config in GroupedLinear
* Multi-param to single param case
* Multi-param to single param case
* Better varnames

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Description
The GroupedLinear module supports either a single parameter registration via GroupedTensor or one param per expert. This PR adds checkpoint loading compatibility across those options.

Type of change
Changes
Checklist: